{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Collaborative Filtering with Neural Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we write a matrix factorization model in PyTorch to solve a recommendation problem, then build a more general neural network model for the same problem."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The MovieLens dataset (ml-latest-small) describes 5-star rating and free-text tagging activity from [MovieLens](https://grouplens.org/datasets/movielens/), a movie recommendation service. It contains 100004 ratings and 1296 tag applications across 9125 movies. To get the data:\n",
    "\n",
    "`wget http://files.grouplens.org/datasets/movielens/ml-latest-small.zip`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MovieLens dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[PosixPath('/data2/yinterian/ml-latest-small/ratings.csv'),\n",
       " PosixPath('/data2/yinterian/ml-latest-small/tags.csv'),\n",
       " PosixPath('/data2/yinterian/ml-latest-small/tiny_training2.csv'),\n",
       " PosixPath('/data2/yinterian/ml-latest-small/links.csv'),\n",
       " PosixPath('/data2/yinterian/ml-latest-small/tiny_val2.csv'),\n",
       " PosixPath('/data2/yinterian/ml-latest-small/README.txt'),\n",
       " PosixPath('/data2/yinterian/ml-latest-small/movies.csv')]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Point PATH at the directory where ml-latest-small.zip was extracted\n",
    "PATH = Path(\"/data2/yinterian/ml-latest-small/\")\n",
    "list(PATH.iterdir())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "userId,movieId,rating,timestamp\r",
      "\r\n",
      "1,31,2.5,1260759144\r",
      "\r\n",
      "1,1029,3.0,1260759179\r",
      "\r\n",
      "1,1061,3.0,1260759182\r",
      "\r\n",
      "1,1129,2.0,1260759185\r",
      "\r\n",
      "1,1172,4.0,1260759205\r",
      "\r\n",
      "1,1263,2.0,1260759151\r",
      "\r\n",
      "1,1287,2.0,1260759187\r",
      "\r\n",
      "1,1293,2.0,1260759148\r",
      "\r\n",
      "1,1339,3.5,1260759125\r",
      "\r\n"
     ]
    }
   ],
   "source": [
    "! head $PATH/ratings.csv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(PATH/\"ratings.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>31</td>\n",
       "      <td>2.5</td>\n",
       "      <td>1260759144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1029</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1260759179</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>1061</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1260759182</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1129</td>\n",
       "      <td>2.0</td>\n",
       "      <td>1260759185</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>1172</td>\n",
       "      <td>4.0</td>\n",
       "      <td>1260759205</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   userId  movieId  rating   timestamp\n",
       "0       1       31     2.5  1260759144\n",
       "1       1     1029     3.0  1260759179\n",
       "2       1     1061     3.0  1260759182\n",
       "3       1     1129     2.0  1260759185\n",
       "4       1     1172     4.0  1260759205"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encoding data\n",
    "We encode the data so that users and movies have contiguous ids. You can think of this as a categorical encoding of our two categorical variables, userId and movieId."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split train and validation before encoding\n",
    "np.random.seed(3)\n",
    "msk = np.random.rand(len(data)) < 0.8\n",
    "train = data[msk].copy()\n",
    "val = data[~msk].copy()"
   ]
  },
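  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# here is a handy function modified from fast.ai\n",
    "def proc_col(col, train_col=None):\n",
    "    \"\"\"Encodes a pandas column with contiguous ids.\n",
    "\n",
    "    If train_col is given, ids are built from train_col and values of col\n",
    "    that do not appear in it are mapped to -1.\n",
    "    \"\"\"\n",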
    "    if train_col is not None:\n",
    "        uniq = train_col.unique()\n",
    "    else:\n",
    "        uniq = col.unique()\n",
    "    name2idx = {o:i for i,o in enumerate(uniq)}\n",
    "    return name2idx, np.array([name2idx.get(x, -1) for x in col]), len(uniq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_data(df, train=None):\n",
    "    \"\"\"Encodes rating data with contiguous user and movie ids.\n",
    "\n",
    "    If train is provided, encodes df with the same encoding as train and\n",
    "    drops rows whose user or movie does not appear in train.\n",
    "    \"\"\"\n",
    "    df = df.copy()\n",
    "    for col_name in [\"userId\", \"movieId\"]:\n",
    "        train_col = None\n",
    "        if train is not None:\n",
    "            train_col = train[col_name]\n",
    "        _,col,_ = proc_col(df[col_name], train_col)\n",
    "        df[col_name] = col\n",
    "        df = df[df[col_name] >= 0]\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    userId  movieId  rating\n",
      "0       11        1       4\n",
      "1       11       23       5\n",
      "2        2       23       5\n",
      "3        2        4       3\n",
      "4       31        1       4\n",
      "5       31       23       4\n",
      "6        4        1       5\n",
      "7        4        3       2\n",
      "8       52        1       1\n",
      "9       52        3       4\n",
      "10      61        3       5\n",
      "11       7       23       1\n",
      "12       7        3       3\n",
      "    userId  movieId  rating\n",
      "0        0        0       4\n",
      "1        0        1       5\n",
      "2        1        1       5\n",
      "3        1        2       3\n",
      "4        2        0       4\n",
      "5        2        1       4\n",
      "6        3        0       5\n",
      "7        3        3       2\n",
      "8        4        0       1\n",
      "9        4        3       4\n",
      "10       5        3       5\n",
      "11       6        1       1\n",
      "12       6        3       3\n"
     ]
    }
   ],
   "source": [
    "# sanity check of the encoding on a tiny dataset\n",
    "LOCAL_PATH = Path(\"images/\")\n",
    "df_t = pd.read_csv(LOCAL_PATH/\"tiny_training2.csv\")\n",
    "df_v = pd.read_csv(LOCAL_PATH/\"tiny_val2.csv\")\n",
    "print(df_t)\n",
    "df_t_e = encode_data(df_t)\n",
    "df_v_e = encode_data(df_v, df_t)\n",
    "df_v_e\n",
    "print(df_t_e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# encoding the train and validation data\n",
    "df_train = encode_data(train)\n",
    "df_val = encode_data(val, train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embedding layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# an Embedding module holding 10 user (or item) embeddings, each of size 3\n",
    "# the embeddings are initialized at random\n",
    "embed = nn.Embedding(10, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.1301,  0.0691, -1.1678],\n",
       "         [-0.9865,  0.4514, -1.4770],\n",
       "         [-1.7121,  0.0701,  0.0481],\n",
       "         [ 1.4485,  0.1340,  0.0099],\n",
       "         [-1.4074, -0.8650, -0.1255],\n",
       "         [-0.1301,  0.0691, -1.1678]]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# given a list of ids we can \"look up\" the embedding corresponding to each id\n",
    "a = torch.LongTensor([[1,2,0,4,5,1]])\n",
    "embed(a)"
   ]
  },
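  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch, assuming only that `torch` is installed, of what the lookup does: an embedding layer stores one row per id in its `weight` matrix, and calling the layer returns the rows selected by the ids. The names `demo_embed`, `ids`, `looked_up` and `indexed` are illustrative and not used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "demo_embed = nn.Embedding(10, 3)        # 10 ids, one 3-dimensional vector per id\n",
    "ids = torch.LongTensor([[1, 2, 0, 4, 5, 1]])\n",
    "\n",
    "looked_up = demo_embed(ids)             # calling the layer, shape (1, 6, 3)\n",
    "indexed = demo_embed.weight[ids]        # plain row indexing into the weight matrix\n",
    "\n",
    "print(looked_up.shape)\n",
    "print(torch.equal(looked_up, indexed))  # checks that both give the same values"
   ]
  },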
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matrix factorization model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MF(nn.Module):\n",
    "    def __init__(self, num_users, num_items, emb_size=100):\n",
    "        super(MF, self).__init__()\n",
    "        self.user_emb = nn.Embedding(num_users, emb_size)\n",
    "        self.item_emb = nn.Embedding(num_items, emb_size)\n",
    "        # start from small positive values instead of the default N(0, 1) init\n",
    "        self.user_emb.weight.data.uniform_(0, 0.05)\n",
    "        self.item_emb.weight.data.uniform_(0, 0.05)\n",
    "        \n",
    "    def forward(self, u, v):\n",
    "        u = self.user_emb(u)  # (batch, emb_size)\n",
    "        v = self.item_emb(v)  # (batch, emb_size)\n",
    "        return (u*v).sum(1)  # dot product of each user/item pair, shape (batch,)"
   ]
  },
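  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what `MF.forward` computes (the symbols $p_u$, $q_i$, $P$, $Q$ and $K$ are notation introduced here, not names from the code): each user $u$ and each movie $i$ gets an embedding vector of length $K$ = `emb_size`, and the predicted rating is their dot product.\n",
    "\n",
    "$$\\hat{r}_{ui} = p_u \\cdot q_i = \\sum_{k=1}^{K} p_{uk}\\, q_{ik}$$\n",
    "\n",
    "Stacking the user vectors as the rows of $P$ and the movie vectors as the rows of $Q$, the full matrix of predicted ratings is $P Q^\\top$, which is where the name matrix factorization comes from. The expression `(u*v).sum(1)` evaluates the sum above for every (user, movie) pair in the batch."
   ]
  },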
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Debugging MF model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>userId</th>\n",
       "      <th>movieId</th>\n",
       "      <th>rating</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>4</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>5</td>\n",
       "      <td>3</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>6</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    userId  movieId  rating\n",
       "0        0        0       4\n",
       "1        0        1       5\n",
       "2        1        1       5\n",
       "3        1        2       3\n",
       "4        2        0       4\n",
       "5        2        1       4\n",
       "6        3        0       5\n",
       "7        3        3       2\n",
       "8        4        0       1\n",
       "9        4        3       4\n",
       "10       5        3       5\n",
       "11       6        1       1\n",
       "12       6        3       3"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_t_e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_users = 7\n",
    "num_items = 4\n",
    "emb_size = 3\n",
    "\n",
    "user_emb = nn.Embedding(num_users, emb_size)\n",
    "item_emb = nn.Embedding(num_items, emb_size)\n",
    "users = torch.LongTensor(df_t_e.userId.values)\n",
    "items = torch.LongTensor(df_t_e.movieId.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "U = user_emb(users)\n",
    "V = item_emb(items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.1547,  0.2277,  0.2442],\n",
       "        [ 0.1547,  0.2277,  0.2442],\n",
       "        [ 0.6601,  0.8225, -1.2139],\n",
       "        [ 0.6601,  0.8225, -1.2139],\n",
       "        [ 0.1672, -1.2177,  0.1403],\n",
       "        [ 0.1672, -1.2177,  0.1403],\n",
       "        [-1.1907, -1.2933, -0.5506],\n",
       "        [-1.1907, -1.2933, -0.5506],\n",
       "        [ 0.1938, -0.0683, -0.8493],\n",
       "        [ 0.1938, -0.0683, -0.8493],\n",
       "        [ 0.8506, -1.1564,  1.1165],\n",
       "        [ 0.8639, -2.5148, -0.8391],\n",
       "        [ 0.8639, -2.5148, -0.8391]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1766,  0.2957,  0.4409],\n",
       "        [ 0.1205,  0.1733,  0.1165],\n",
       "        [ 0.5143,  0.6258, -0.5793],\n",
       "        [-0.5603,  0.3582, -0.5370],\n",
       "        [-0.1909, -1.5812,  0.2533],\n",
       "        [ 0.1303, -0.9266,  0.0670],\n",
       "        [ 1.3594, -1.6793, -0.9940],\n",
       "        [-0.2324,  1.4822,  0.5151],\n",
       "        [-0.2212, -0.0887, -1.5335],\n",
       "        [ 0.0378,  0.0783,  0.7947],\n",
       "        [ 0.1660,  1.3253, -1.0447],\n",
       "        [ 0.6730, -1.9135, -0.4004],\n",
       "        [ 0.1686,  2.8820,  0.7851]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# element-wise multiplication\n",
    "U*V "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.5600,  0.4103,  0.5608, -0.7391, -1.5187, -0.7294, -1.3139,\n",
       "         1.7649, -1.8434,  0.9108,  0.4466, -1.6409,  3.8357])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# what we want is the dot product of each row of U with the matching row of V\n",
    "(U*V).sum(1) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training MF model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "671 8442\n"
     ]
    }
   ],
   "source": [
    "num_users = len(df_train.userId.unique())\n",
    "num_items = len(df_train.movieId.unique())\n",
    "print(num_users, num_items) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MF(num_users, num_items, emb_size=100) # .cuda() if you have a GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epocs(model, epochs=10, lr=0.01, wd=0.0, unsqueeze=False):\n",
    "    \"\"\"Trains on all of df_train at once (one full-batch Adam step per epoch).\"\"\"\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
    "    model.train()\n",
    "    # the training tensors do not change between epochs, so build them once\n",
    "    users = torch.LongTensor(df_train.userId.values) # .cuda()\n",
    "    items = torch.LongTensor(df_train.movieId.values) #.cuda()\n",
    "    ratings = torch.FloatTensor(df_train.rating.values) #.cuda()\n",
    "    if unsqueeze:\n",
    "        ratings = ratings.unsqueeze(1)  # match a model output of shape (N, 1)\n",
    "    for i in range(epochs):\n",
    "        y_hat = model(users, items)\n",
    "        loss = F.mse_loss(y_hat, ratings)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        print(loss.item())\n",
    "    test_loss(model, unsqueeze)"
   ]
  },
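  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the objective minimized by `train_epocs` can be sketched as follows ($N$, $\\theta$ and $\\lambda$ are notation introduced here, with $\\theta$ standing for all model parameters and $\\lambda$ for the `wd` argument). `F.mse_loss` averages the squared error over the $N$ training ratings, and each epoch is a single full-batch step:\n",
    "\n",
    "$$L(\\theta) = \\frac{1}{N} \\sum_{(u,i)} \\left(\\hat{r}_{ui} - r_{ui}\\right)^2$$\n",
    "\n",
    "With `weight_decay=wd`, `torch.optim.Adam` adds $\\lambda\\,\\theta$ to the gradient before computing its update, which corresponds to an extra L2 penalty $\\frac{\\lambda}{2}\\lVert\\theta\\rVert^2$ on the parameters."
   ]
  },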
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([79799])\n",
      "torch.Size([79799, 1])\n"
     ]
    }
   ],
   "source": [
    "# Here is what unsqueeze does: it adds a dimension of size 1\n",
    "ratings = torch.FloatTensor(df_train.rating.values)\n",
    "print(ratings.shape)\n",
    "ratings = ratings.unsqueeze(1) # .cuda()\n",
    "print(ratings.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_loss(model, unsqueeze=False):\n",
    "    \"\"\"Prints the MSE on the validation set (df_val).\"\"\"\n",
    "    model.eval()\n",
    "    users = torch.LongTensor(df_val.userId.values) #.cuda()\n",
    "    items = torch.LongTensor(df_val.movieId.values) #.cuda()\n",
    "    ratings = torch.FloatTensor(df_val.rating.values) #.cuda()\n",
    "    if unsqueeze:\n",
    "        ratings = ratings.unsqueeze(1)\n",
    "    with torch.no_grad():  # no gradients are needed for evaluation\n",
    "        y_hat = model(users, items)\n",
    "        loss = F.mse_loss(y_hat, ratings)\n",
    "    print(\"test loss %.3f \" % loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13.23068904876709\n",
      "5.119534015655518\n",
      "2.3902299404144287\n",
      "3.441521406173706\n",
      "0.9096018671989441\n",
      "1.8109439611434937\n",
      "2.749631643295288\n",
      "2.278921604156494\n",
      "1.1593214273452759\n",
      "0.925656795501709\n",
      "test loss 1.947 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=10, lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.7027523517608643\n",
      "1.0512956380844116\n",
      "0.7498359680175781\n",
      "0.6950282454490662\n",
      "0.7596880197525024\n",
      "0.8397833108901978\n",
      "0.8818210363388062\n",
      "0.8753886818885803\n",
      "0.8334189653396606\n",
      "0.7767009735107422\n",
      "0.7246581315994263\n",
      "0.6901594400405884\n",
      "0.6771144866943359\n",
      "0.6810137033462524\n",
      "0.69219970703125\n",
      "test loss 0.894 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=15, lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7007282376289368\n",
      "0.6625022888183594\n",
      "0.6684340834617615\n",
      "0.6455244421958923\n",
      "0.6380830407142639\n",
      "0.6450700759887695\n",
      "0.6408411264419556\n",
      "0.6256920099258423\n",
      "0.6144804358482361\n",
      "0.6132143139839172\n",
      "0.6140048503875732\n",
      "0.6083489060401917\n",
      "0.5969548225402832\n",
      "0.5860226154327393\n",
      "0.5791704058647156\n",
      "test loss 0.822 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=15, lr=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MF with bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MF_bias(nn.Module):\n",
    "    def __init__(self, num_users, num_items, emb_size=100):\n",
    "        super(MF_bias, self).__init__()\n",
    "        self.user_emb = nn.Embedding(num_users, emb_size)\n",
    "        self.user_bias = nn.Embedding(num_users, 1)\n",
    "        self.item_emb = nn.Embedding(num_items, emb_size)\n",
    "        self.item_bias = nn.Embedding(num_items, 1)\n",
    "        self.user_emb.weight.data.uniform_(0,0.05)\n",
    "        self.item_emb.weight.data.uniform_(0,0.05)\n",
    "        self.user_bias.weight.data.uniform_(-0.01,0.01)\n",
    "        self.item_bias.weight.data.uniform_(-0.01,0.01)\n",
    "        \n",
    "    def forward(self, u, v):\n",
    "        U = self.user_emb(u)\n",
    "        V = self.item_emb(v)\n",
    "        # squeeze only the last dimension: (batch, 1) -> (batch,)\n",
    "        b_u = self.user_bias(u).squeeze(1)\n",
    "        b_v = self.item_bias(v).squeeze(1)\n",
    "        return (U*V).sum(1) + b_u + b_v"
   ]
  },
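  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what `MF_bias.forward` computes, write $p_u$ and $q_i$ for the rows of `user_emb` and `item_emb` selected by user $u$ and movie $i$, and $b_u$ and $b_i$ for the matching entries of `user_bias` and `item_bias` (this notation is introduced here and does not appear in the code):\n",
    "\n",
    "$$\\hat{r}_{ui} = p_u \\cdot q_i + b_u + b_i$$\n",
    "\n",
    "A plausible reading is that $b_u$ absorbs how generously a user rates on average and $b_i$ how well liked a movie is overall, leaving the dot product to model the interaction between the two."
   ]
  },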
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MF_bias(num_users, num_items, emb_size=100) #.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13.233644485473633\n",
      "9.459980964660645\n",
      "4.618295669555664\n",
      "1.2266862392425537\n",
      "2.4537320137023926\n",
      "3.888521432876587\n",
      "2.6157896518707275\n",
      "1.1573508977890015\n",
      "0.8204843997955322\n",
      "1.3100122213363647\n",
      "test loss 2.126 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=10, lr=0.05, wd=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.9130752086639404\n",
      "1.3447301387786865\n",
      "0.9572998285293579\n",
      "0.7714419364929199\n",
      "0.752704381942749\n",
      "0.8091325759887695\n",
      "0.8543495535850525\n",
      "0.8524782657623291\n",
      "0.8114585876464844\n",
      "0.7577651739120483\n",
      "test loss 0.851 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=10, lr=0.01, wd=1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7163214087486267\n",
      "0.7023102045059204\n",
      "0.6904919147491455\n",
      "0.6807348728179932\n",
      "0.6728458404541016\n",
      "0.6666097044944763\n",
      "0.6618107557296753\n",
      "0.6582220792770386\n",
      "0.6556380391120911\n",
      "0.6538312435150146\n",
      "test loss 0.805 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=10, lr=0.001, wd=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that these models are sensitive to weight initialization, the choice of optimization algorithm, and regularization."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Neural Network Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: the user and item embeddings are concatenated rather than multiplied together,\n",
    "# so they could have different sizes.\n",
    "# Better results may be possible with further tuning of the regularization.\n",
    "    \n",
    "class CollabFNet(nn.Module):\n",
    "    def __init__(self, num_users, num_items, emb_size=100, n_hidden=10):\n",
    "        super(CollabFNet, self).__init__()\n",
    "        self.user_emb = nn.Embedding(num_users, emb_size)\n",
    "        self.item_emb = nn.Embedding(num_items, emb_size)\n",
    "        self.lin1 = nn.Linear(emb_size*2, n_hidden)\n",
    "        self.lin2 = nn.Linear(n_hidden, 1)\n",
    "        self.drop1 = nn.Dropout(0.1)\n",
    "        \n",
    "    def forward(self, u, v):\n",
    "        U = self.user_emb(u)\n",
    "        V = self.item_emb(v)\n",
    "        x = F.relu(torch.cat([U, V], dim=1))\n",
    "        x = self.drop1(x)\n",
    "        x = F.relu(self.lin1(x))\n",
    "        x = self.lin2(x)\n",
    "        return x"
   ]
  },
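  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough summary of `CollabFNet.forward` (the symbols are notation introduced here: $p_u$ and $q_i$ are the user and movie embeddings, and $W_1, c_1$ and $W_2, c_2$ stand for the weights and biases of `lin1` and `lin2`), the two embeddings are concatenated and passed through a small feed-forward network:\n",
    "\n",
    "$$x = \\mathrm{ReLU}([p_u ; q_i]), \\qquad h = \\mathrm{ReLU}(W_1\\, \\mathrm{dropout}(x) + c_1), \\qquad \\hat{r}_{ui} = W_2\\, h + c_2$$\n",
    "\n",
    "With the defaults, $x$ has $2 \\times 100 = 200$ entries and $h$ has 10. The output has shape `(batch, 1)`, which is why the ratings are passed through `unsqueeze(1)` when this model is trained."
   ]
  },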
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CollabFNet(num_users, num_items, emb_size=100) #.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13.101761817932129\n",
      "1.957230806350708\n",
      "1.2605514526367188\n",
      "1.3381402492523193\n",
      "1.061022162437439\n",
      "1.1385098695755005\n",
      "0.9165319800376892\n",
      "0.9622549414634705\n",
      "0.8723138570785522\n",
      "0.8084518909454346\n",
      "0.8500765562057495\n",
      "0.7535637617111206\n",
      "0.791947603225708\n",
      "0.7653028964996338\n",
      "0.7301635146141052\n",
      "test loss 0.869 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=15, lr=0.05, wd=1e-6, unsqueeze=True) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7691234350204468\n",
      "0.9072751402854919\n",
      "0.7757670879364014\n",
      "0.7180655598640442\n",
      "0.7918605208396912\n",
      "0.7724899053573608\n",
      "0.7119362950325012\n",
      "0.7106000185012817\n",
      "0.7403213977813721\n",
      "0.7438958883285522\n",
      "test loss 0.816 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=10, lr=0.01, wd=1e-6, unsqueeze=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7163267731666565\n",
      "0.7032808065414429\n",
      "0.695513904094696\n",
      "0.6967512369155884\n",
      "0.6998187303543091\n",
      "0.700666606426239\n",
      "0.7004959583282471\n",
      "0.6982167959213257\n",
      "0.6955875158309937\n",
      "0.694402813911438\n",
      "test loss 0.796 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=10, lr=0.001, wd=1e-6, unsqueeze=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6919353008270264\n",
      "0.6934647560119629\n",
      "0.6922585368156433\n",
      "0.6942275762557983\n",
      "0.6926798224449158\n",
      "0.6916202902793884\n",
      "0.6911264061927795\n",
      "0.6923496127128601\n",
      "0.6922929286956787\n",
      "0.6904215812683105\n",
      "test loss 0.795 \n"
     ]
    }
   ],
   "source": [
    "train_epocs(model, epochs=10, lr=0.001, wd=1e-6, unsqueeze=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "* This notebook is based on [lesson 5 of Jeremy Howard's Deep Learning Course](https://github.com/fastai/fastai/blob/master/courses/dl1/lesson5-movielens.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
